{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JyWknuJRdhK8"
      },
      "source": [
        "# MPLP: MAML Sinusoidal 2 step, 10 shot\n",
        "\n",
        "\n",
        "Copyright 2020 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "c6GXN11wdbvU"
      },
      "outputs": [],
      "source": [
        "# @title Connect to internal TF kernel and run this.\n",
        "import os\n",
        "import io\n",
        "import numpy as np\n",
        "import glob\n",
        "\n",
        "import tensorflow.compat.v2 as tf\n",
        "tf.enable_v2_behavior()\n",
        "tf.get_logger().setLevel('ERROR')\n",
        "\n",
        "import matplotlib.pyplot as plt # visualization\n",
        "\n",
        "from collections import defaultdict\n",
        "import random\n",
        "\n",
        "import itertools\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow.compat.v2 as tf\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import IPython.display as display\n",
        "from IPython.display import clear_output\n",
        "\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "import os\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QBo8SVVqmLGU"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade -e git+https://github.com/google-research/self-organising-systems.git#egg=mplp\u0026subdirectory=mplp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cAm7UBfbjJz6"
      },
      "outputs": [],
      "source": [
        "# symlink for saved models.\n",
        "!ln -s src/mplp/mplp/savedmodels savedmodels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Rwc7l57-HZaI"
      },
      "outputs": [],
      "source": [
        "from mplp.tf_layers import MPDense\n",
        "from mplp.tf_layers import MPActivation\n",
        "from mplp.tf_layers import MPSoftmax\n",
        "from mplp.tf_layers import MPL2Loss\n",
        "from mplp.tf_layers import MPNetwork\n",
        "from mplp.sinusoidals import SinusoidalsDS\n",
        "from mplp.util import SamplePool\n",
        "from mplp.training import TrainingRegime"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qKEOL1bJp3XQ"
      },
      "source": [
        "The task is to fit sinusoidals from randomly initialized networks.\n",
        "\n",
        "Therefore, there are:\n",
        "* Outer batch size = 4 number of tasks at every step. Each has a different network, different amplitude and different phase.\n",
        "* Inner batch size = 10 number of examples for each forward/backward steps.\n",
        "* num steps = 5, number of inner steps the network has to get better.\n",
        "* train/eval split: the network only sees train instances during forward/backward. The meta-learning regime *may* choose to use eval splits as well, MAML-style."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "Yx3ubhzRgPHh"
      },
      "outputs": [],
      "source": [
        "# @title create dataset and plot it\n",
        "OUTER_BATCH_SIZE = 4\n",
        "INNER_BATCH_SIZE = 10\n",
        "NUM_STEPS = 2\n",
        "\n",
        "ds_factory = SinusoidalsDS()\n",
        "\n",
        "ds = ds_factory.create_ds(OUTER_BATCH_SIZE, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "ds_iter = iter(ds)\n",
        "\n",
        "# Utility range\n",
        "xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)).astype(np.float32)\n",
        "\n",
        "xtb, ytb, xeb, yeb = next(ds_iter)\n",
        "plt.figure(figsize=(14, 10))\n",
        "colors = itertools.cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])\n",
        "for xts, yts, xes, yes in zip(xtb, ytb, xeb, yeb):\n",
        "  c_t = next(colors)\n",
        "  c_e = next(colors)\n",
        "  markers = itertools.cycle((',', '+', '.', 'o', '*')) \n",
        "  for xtsib, ytsib, xesib, yesib in zip(xts, yts, xes, yes):\n",
        "    marker = next(markers)\n",
        "    plt.scatter(xtsib, ytsib, c=c_t, marker=marker)\n",
        "    plt.scatter(xesib, yesib, c=c_e, marker=marker)\n",
        "\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Aq5JC810rXYF"
      },
      "source": [
        "Create a MP network:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lD0T9QJFnEaz"
      },
      "outputs": [],
      "source": [
        "\n",
        "# This is the size of the message passed.\n",
        "message_size = 8\n",
        "stateful = False\n",
        "stateful_hidden_n = 15\n",
        "\n",
        "# This network is keras-style initialized.\n",
        "# If you want to create a single layer, you need to pass it also the in_dim\n",
        "# and message size.\n",
        "network = MPNetwork(\n",
        "    [\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), \n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     ],\n",
        "     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))\n",
        "network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)\n",
        "\n",
        "# see trainable weights:\n",
        "tr_w = network.get_trainable_weights()\n",
        "print(\"trainable weights:\")\n",
        "tot_w = 0\n",
        "for w in tr_w:\n",
        "  w = w.numpy()\n",
        "  w_size = w.size\n",
        "  tot_w += w_size\n",
        "\n",
        "  print(w.shape, w_size)\n",
        "print(\"tot n:\", tot_w)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "776DRprNBM4p"
      },
      "outputs": [],
      "source": [
        "# for MAML training, we need one and only one set of variables.\n",
        "trained_pfw = [tf.Variable(t) for t in network.init()]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Xv5JvNrLaDya"
      },
      "outputs": [],
      "source": [
        "num_steps = tf.constant(NUM_STEPS)\n",
        "\n",
        "learning_schedule = 1e-4\n",
        "\n",
        "training_regime = TrainingRegime(\n",
        "    network, heldout_weight=1.0, hint_loss_ratio=None, remember_loss_ratio=None)\n",
        "\n",
        "last_step = 0\n",
        "\n",
        "# minibatch init, allowing to initialize by looking at more\n",
        "# than just one step.\n",
        "# Likewise, this can be run multiple times to improve the initialization.\n",
        "for j in range(1):\n",
        "  print(\"on\", j)\n",
        "  stats = []\n",
        "  pfw = trained_pfw\n",
        "\n",
        "  x_b, y_b, _, _ = next(ds_iter)\n",
        "  x_b, y_b = x_b[0], y_b[0]\n",
        "  for i in range(NUM_STEPS):\n",
        "    pfw, stats_i = network.minibatch_init(x_b[i],  y_b[i], x_b[i].shape[0], pfw=pfw)\n",
        "    stats.append(stats_i)\n",
        "  # update\n",
        "  network.update_statistics(stats, update_perc=1.)\n",
        "\n",
        "  print(\"final mean:\")\n",
        "  for p in tf.nest.flatten(pfw):\n",
        "    print(p.shape, tf.reduce_mean(p), tf.math.reduce_std(p))\n",
        "\n",
        "\n",
        "\n",
        "# The outer loop here uses Adam. SGD/Momentum are more stable but way slower.\n",
        "trainer = tf.keras.optimizers.Adam(learning_schedule)\n",
        "\n",
        "loss_log = []\n",
        "def smoothen(l, lookback=20):\n",
        "  # first of all, if it's a nan, change it to a high value\n",
        "  kernel = [1./lookback] * lookback\n",
        "  return np.convolve(l[0:1] * (lookback - 1) + l, kernel, \"valid\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eVYI6NAcLj1M"
      },
      "outputs": [],
      "source": [
        "print([p.shape for p in trained_pfw])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "sWlFhwL5od7q"
      },
      "outputs": [],
      "source": [
        "training_steps = 200000\n",
        "print(\"Stop this block whenever after 1-2k steps. It's good even very early.\")\n",
        "\n",
        "@tf.function\n",
        "def step(pfw, xts, yts, xes, yes, num_steps):\n",
        "  print(\"compiling\")\n",
        "  with tf.GradientTape() as g:\n",
        "    pfw_serialized = network.serialize_pfw(pfw)\n",
        "    l, _, _ = training_regime.batch_mp_loss(\n",
        "        pfw_serialized, xts, yts, xes, yes, num_steps, same_pfw=True)\n",
        "  all_weights = network.get_trainable_weights()\n",
        "  all_weights += pfw\n",
        "  grads = g.gradient(l, all_weights)\n",
        "  # Try grad clipping to avoid explosions.\n",
        "  grads = [g/(tf.norm(g)+1e-8) for g in grads]\n",
        "  trainer.apply_gradients(zip(grads, all_weights))\n",
        "  return l\n",
        "\n",
        "\n",
        "import time\n",
        "start_time = time.time()\n",
        "\n",
        "for i in range(last_step + 1, last_step +1 + training_steps):\n",
        "  last_step = i\n",
        "\n",
        "  tmp_t = time.time()\n",
        "  xts, yts, xes, yes = next(ds_iter)\n",
        "\n",
        "  l = step(trained_pfw, xts, yts, xes, yes, num_steps)\n",
        "  loss_log.append(l)\n",
        "\n",
        "  if i % 50 == 0:\n",
        "    print(i)\n",
        "    print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "  if i % 500 == 0:\n",
        "    plt.plot(smoothen(loss_log, 100), label='mp')\n",
        "    plt.yscale('log')\n",
        "    #plt.ylim(0.0, 1e-1)\n",
        "    plt.legend()\n",
        "    plt.show()\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QUmKqMWxqLtI"
      },
      "outputs": [],
      "source": [
        "\n",
        "print(loss_log[-1])\n",
        "plt.plot(smoothen(loss_log, 100), label='mp')\n",
        "plt.yscale('log')\n",
        "plt.ylim(0.0, 1e-1)\n",
        "plt.gca().yaxis.grid(True)\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tf4ENAZs-zP5"
      },
      "source": [
        "#Proper evaluation: run 100 different few-shot instances with totally new network params.\n",
        "\n",
        "The train loss is computed only on points that the network has already observed.\n",
        "\n",
        "The eval loss is computed on the entire range [-5, 5]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "SFbFuJ_GDBU5"
      },
      "outputs": [],
      "source": [
        "!mkdir tmp\n",
        "\n",
        "!ls tmp -R"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NI47xog03QCg"
      },
      "outputs": [],
      "source": [
        "\n",
        "!mkdir tmp\n",
        "file_path = \"tmp/maml_sin_net_weights\"\n",
        "\n",
        "network.save_weights(file_path, last_step)\n",
        "\n",
        "with open(\"tmp/maml_sin_prior_weights_{:08d}.npy\".format(\n",
        "    last_step), \"wb\") as fout:\n",
        "  prior_to_save = tf.concat([tf.reshape(e, [-1]) for e in trained_pfw], 0)\n",
        "  np.save(fout, prior_to_save.numpy())\n",
        "\n",
        "!ls -lh tmp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "MPJlM6koyeUN"
      },
      "outputs": [],
      "source": [
        "# try to save and load\n",
        "file_path = \"savedmodels/maml_sin_net_weights\"\n",
        "network = MPNetwork(\n",
        "    [\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), \n",
        "     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),\n",
        "     ],\n",
        "     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))\n",
        "network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)\n",
        "network.load_weights(file_path)\n",
        "\n",
        "# Load prior weights too\n",
        "import tensorflow.io.gfile as gfile\n",
        "matcher = \"savedmodels/maml_sin_prior_weights_*.npy\"\n",
        "filenames = sorted(gfile.glob(matcher), reverse=True)\n",
        "assert len(filenames) \u003e 0, \"No files matching {}\".format(matcher)\n",
        "filename = filenames[0]\n",
        "print(filename)\n",
        "\n",
        "# load array\n",
        "with gfile.GFile(filename, \"rb\") as fin:\n",
        "  serialized_weights = np.load(fin)\n",
        "\n",
        "trained_pfw_shapes = [t.shape for t in network.init()]\n",
        "trained_pfw_flat_sizes = [int(tf.reshape(t, [-1]).shape[0]) for t in network.init()]\n",
        "\n",
        "print(serialized_weights.shape, trained_pfw_flat_sizes)\n",
        "all_weights_flat_split = tf.split(serialized_weights,\n",
        "                                  trained_pfw_flat_sizes)\n",
        "trained_pfw = [tf.reshape(t, s) for t, s in zip(\n",
        "    all_weights_flat_split, trained_pfw_shapes)]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Ty5_TQFm8vGu"
      },
      "outputs": [],
      "source": [
        "\n",
        "eval_tot_steps = 100\n",
        "\n",
        "tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])\n",
        "ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0-step.\n",
        "\n",
        "@tf.function\n",
        "def get_loss(pfw, x, y):\n",
        "  predictions, _= network.forward(pfw, x)\n",
        "  loss, _ = network.compute_loss(predictions, y)\n",
        "  return loss\n",
        "\n",
        "start_time = time.time()\n",
        "\n",
        "for r in range(eval_tot_steps):\n",
        "  p_fw = trained_pfw\n",
        "\n",
        "  A, ph = ds_factory._create_task()\n",
        "\n",
        "  targets = A * np.sin(xrange_inputs + ph)\n",
        "\n",
        "  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "\n",
        "  # initial loss.\n",
        "  loss = get_loss(p_fw, xrange_inputs, targets)\n",
        "  ev_losses[r, 0] = tf.reduce_mean(loss)\n",
        "\n",
        "  for i in range(NUM_STEPS):\n",
        "    p_fw, _= network.inner_update(p_fw, xt[i], yt[i])\n",
        "\n",
        "    # loss specific to only what we observe.\n",
        "    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "    loss = get_loss(p_fw, x_observed_so_far, y_observed_so_far)\n",
        "    tr_losses[r, i] = tf.reduce_mean(loss)\n",
        "\n",
        "    # Plotting for the continuous input range\n",
        "    loss = get_loss(p_fw, xrange_inputs, targets)\n",
        "    ev_losses[r, i + 1] = tf.reduce_mean(loss)\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "\n",
        "tr_losses_m = np.mean(tr_losses, axis=0)\n",
        "ev_losses_m = np.mean(ev_losses, axis=0)\n",
        "\n",
        "tr_losses_sd = np.std(tr_losses, axis=0)\n",
        "ev_losses_sd = np.std(ev_losses, axis=0)\n",
        "\n",
        "print(\"tr_l, m:\", tr_losses_m, \" sd:\", tr_losses_sd)\n",
        "print(\"ev_l, m:\", ev_losses_m, \" sd:\", ev_losses_sd)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)\n",
        "plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')\n",
        "plt.ylim(0.0, 0.025)\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"L2 loss\")\n",
        "plt.legend()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "czm5dzijZE2V"
      },
      "outputs": [],
      "source": [
        "print(tr_losses_m, ev_losses_m)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mV-JovFUs1V1"
      },
      "outputs": [],
      "source": [
        "# @title Show an example run:\n",
        "\n",
        "fig, axs = plt.subplots(5, 2, figsize=(10,15))\n",
        "\n",
        "for fig_n in range(5):\n",
        "  p_fw = trained_pfw\n",
        "\n",
        "  n_plot = 5\n",
        "  plot_every = 1\n",
        "\n",
        "  predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "\n",
        "  A, ph = ds_factory._create_task()\n",
        "\n",
        "  targets = A * np.sin(xrange_inputs + ph)\n",
        "  axs[fig_n][0].plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "  predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "  axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "  tr_losses = []\n",
        "  ev_losses = []\n",
        "\n",
        "  for i in range(NUM_STEPS):\n",
        "    p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])\n",
        "\n",
        "    # loss specific to only what we observe.\n",
        "    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "    predictions, _= network.forward(p_fw, x_observed_so_far)\n",
        "    loss, _ = network.compute_loss(predictions, y_observed_so_far)\n",
        "    tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "    # Plotting for the continuous input range\n",
        "    predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "    if (i+1) % plot_every == 0:\n",
        "      axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "    loss, _ = network.compute_loss(predictions, targets)\n",
        "    ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  axs[fig_n][1].plot(np.arange(len(tr_losses)), tr_losses, label='tr_losses')\n",
        "  axs[fig_n][1].plot(np.arange(len(ev_losses)), ev_losses, label='ev_losses')\n",
        "\n",
        "\n",
        "axs[0][0].legend()\n",
        "axs[0][1].legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "0Lw7YaJaJfvB"
      },
      "outputs": [],
      "source": [
        "# @title Single run for drawing.\n",
        "\n",
        "p_fw = trained_pfw\n",
        "\n",
        "plot_every = 1\n",
        "\n",
        "predictions, _ = network.forward(p_fw, xrange_inputs)\n",
        "\n",
        "A, ph = ds_factory._create_task()\n",
        "\n",
        "targets = A * np.sin(xrange_inputs + ph)\n",
        "plt.plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "tr_losses = []\n",
        "ev_losses = []\n",
        "\n",
        "for i in range(NUM_STEPS):\n",
        "  p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])\n",
        "\n",
        "  # loss specific to only what we observe.\n",
        "  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "  predictions, _dh= network.forward(p_fw, x_observed_so_far)\n",
        "  loss, _ = network.compute_loss(predictions, y_observed_so_far)\n",
        "  tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  # Plotting for the continuous input range\n",
        "  predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "  if (i+1) % plot_every == 0:\n",
        "    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "  loss, _ = network.compute_loss(predictions, targets)\n",
        "  ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "\n",
        "plt.legend()\n",
        "\n",
        "\n",
        "with open(\"tmp/mplp_example_run.png\", \"wb\") as fout:\n",
        "  plt.savefig(fout)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1ZnWudcy_qUV"
      },
      "source": [
        "# Compare it with MAML run"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "dzCGXHHAT2FB"
      },
      "outputs": [],
      "source": [
        "maml_pfw = [tf.Variable(t) for t in network.init()]\n",
        "maml_last_step = 0\n",
        "maml_loss_log = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "UeGe4aaOKJ6r"
      },
      "outputs": [],
      "source": [
        "training_steps = 200000\n",
        "print(\"Stop this block whenever after 1-2k steps. It's good even very early.\")\n",
        "\n",
        "def update_pfw(pfw, xt, yt, num_steps):\n",
        "  for i in tf.range(num_steps):\n",
        "    with tf.GradientTape() as g:\n",
        "      g.watch(pfw)\n",
        "      prediction, _ = network.forward(pfw, xt[i])\n",
        "      loss, _ = network.compute_loss(prediction, yt[i])\n",
        "      loss = tf.reduce_mean(loss)\n",
        "    grads = g.gradient(loss, pfw)\n",
        "    \n",
        "    pfw = [p - 0.05 * pg for p, pg in zip(pfw, grads)]\n",
        "  return pfw\n",
        "\n",
        "def single_loss(pfw, xt, yt, xe, ye, num_steps):\n",
        "  new_pfw = update_pfw(pfw, xt, yt, num_steps)\n",
        "\n",
        "  prediction, _ = network.forward(new_pfw, xe)\n",
        "  cv_loss, _ = network.compute_loss(prediction, ye)\n",
        "  cv_loss = tf.reduce_mean(cv_loss)\n",
        "  return cv_loss\n",
        "\n",
        "def batch_maml_loss(pfw, xts, yts, xes, yes, num_steps):\n",
        "  task_losses = []\n",
        "  for i in range(len(xts)):\n",
        "    task_losses.append(\n",
        "        single_loss(pfw, xts[i], yts[i], xes[i], yes[i], num_steps))\n",
        "  return tf.reduce_mean(tf.stack(task_losses))\n",
        "\n",
        "@tf.function\n",
        "def maml_step(pfw, xts, yts, xes, yes, num_steps):\n",
        "  print(\"compiling\")\n",
        "  with tf.GradientTape() as g:\n",
        "    l = batch_maml_loss(pfw, xts, yts, xes, yes, num_steps)\n",
        "  grads = g.gradient(l, pfw)\n",
        "  # Try grad clipping to avoid explosions.\n",
        "  grads = [g/(tf.norm(g)+1e-8) for g in grads]\n",
        "  trainer.apply_gradients(zip(grads, pfw))\n",
        "  return l\n",
        "\n",
        "\n",
        "import time\n",
        "start_time = time.time()\n",
        "\n",
        "for i in range(maml_last_step + 1, maml_last_step +1 + training_steps):\n",
        "  maml_last_step = i\n",
        "\n",
        "  tmp_t = time.time()\n",
        "  xts, yts, xes, yes = next(ds_iter)\n",
        "\n",
        "  l = maml_step(maml_pfw, xts, yts, xes, yes, num_steps)\n",
        "  maml_loss_log.append(l)\n",
        "\n",
        "  if i % 50 == 0:\n",
        "    print(i)\n",
        "    print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "  if i % 500 == 0:\n",
        "    plt.plot(smoothen(maml_loss_log, 100), label='mp')\n",
        "    plt.yscale('log')\n",
        "    plt.legend()\n",
        "    plt.show()\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NKMDOGpORYyF"
      },
      "outputs": [],
      "source": [
        "eval_tot_steps = 100\n",
        "\n",
        "tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])\n",
        "ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0 step.\n",
        "\n",
        "for r in range(eval_tot_steps):\n",
        "  # We need to transform these into variables.\n",
        "  p_fw = maml_pfw\n",
        "\n",
        "  A, ph = ds_factory._create_task()\n",
        "\n",
        "  targets = A * np.sin(xrange_inputs + ph)\n",
        "\n",
        "  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "\n",
        "  # initial loss.\n",
        "  predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "  loss, _ = network.compute_loss(predictions, targets)\n",
        "  ev_losses[r, 0] = tf.reduce_mean(loss)\n",
        "\n",
        "  for i in range(NUM_STEPS):\n",
        "    p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)\n",
        "\n",
        "    # loss specific to only what we observe.\n",
        "    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "    predictions, _= network.forward(p_fw, x_observed_so_far)\n",
        "    loss, _ = network.compute_loss(predictions, y_observed_so_far)\n",
        "    tr_losses[r, i] = tf.reduce_mean(loss)\n",
        "\n",
        "    # Plotting for the continuous input range\n",
        "    predictions, _= network.forward(p_fw, xrange_inputs)\n",
        "    loss, _ = network.compute_loss(predictions, targets)\n",
        "    ev_losses[r, i + 1] = tf.reduce_mean(loss)\n",
        "\n",
        "tr_losses_m = np.mean(tr_losses, axis=0)\n",
        "ev_losses_m = np.mean(ev_losses, axis=0)\n",
        "tr_losses_sd = np.std(tr_losses, axis=0)\n",
        "ev_losses_sd = np.std(ev_losses, axis=0)\n",
        "print(\"tr_l, m:\", tr_losses_m, \" sd:\", tr_losses_sd)\n",
        "print(\"ev_l, m:\", ev_losses_m, \" sd:\", ev_losses_sd)\n",
        "\n",
        "ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]\n",
        "plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)\n",
        "plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')\n",
        "\n",
        "ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]\n",
        "plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)\n",
        "plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')\n",
        "plt.ylim(0.0, 0.04)\n",
        "plt.xlabel(\"num steps\")\n",
        "plt.ylabel(\"L2 loss\")\n",
        "plt.legend()\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vmJj63OEO35M"
      },
      "outputs": [],
      "source": [
        "# Same task as MPLP for the same drawing.\n",
        "# NOTE(review): A and ph are not assigned in this cell, so they must already\n",
        "# exist in the kernel; confirm which earlier cell defines them.\n",
        "\n",
        "p_fw = maml_pfw\n",
        "\n",
        "plot_every = 1\n",
        "\n",
        "targets = A * np.sin(xrange_inputs + ph)\n",
        "plt.plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "# Predictions from the starting parameters (maml_pfw), before any update.\n",
        "predictions, _ = network.forward(p_fw, xrange_inputs)\n",
        "plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "# NOTE(review): xe, ye are not read anywhere else in this cell.\n",
        "xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "tr_losses = []\n",
        "ev_losses = []\n",
        "\n",
        "for i in range(NUM_STEPS):\n",
        "  # One update using only the i-th slice of the training data.\n",
        "  p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)\n",
        "\n",
        "  # loss specific to only what we observe.\n",
        "  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "  predictions, _ = network.forward(p_fw, x_observed_so_far)\n",
        "  loss, _ = network.compute_loss(predictions, y_observed_so_far)\n",
        "  tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  # Plotting for the continuous input range\n",
        "  predictions, _ = network.forward(p_fw, xrange_inputs)\n",
        "  if (i+1) % plot_every == 0:\n",
        "    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "  loss, _ = network.compute_loss(predictions, targets)\n",
        "  ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "plt.legend()\n",
        "\n",
        "# open() does not create missing directories, so make sure tmp/ exists first.\n",
        "os.makedirs('tmp', exist_ok=True)\n",
        "with open(\"tmp/maml_example_run.png\", \"wb\") as fout:\n",
        "  plt.savefig(fout)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TdbjW_uplJ4H"
      },
      "outputs": [],
      "source": [
        "# @title Show an example run:\n",
        "n_plot = 5\n",
        "# Plot roughly n_plot intermediate curves, but at least every step.\n",
        "plot_every = max(1, NUM_STEPS // n_plot)\n",
        "\n",
        "p_fw = maml_pfw\n",
        "\n",
        "# Task parameters; used below as amplitude (A) and phase (ph) of the target sine.\n",
        "A, ph = ds_factory._create_task()\n",
        "\n",
        "targets = A * np.sin(xrange_inputs + ph)\n",
        "plt.plot(xrange_inputs, targets, label='target')\n",
        "\n",
        "# Predictions from the starting parameters (maml_pfw), before any update.\n",
        "predictions, _ = network.forward(p_fw, xrange_inputs)\n",
        "plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))\n",
        "\n",
        "xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "# NOTE(review): xe, ye are not read anywhere else in this cell.\n",
        "xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)\n",
        "tr_losses = []\n",
        "ev_losses = []\n",
        "\n",
        "for i in range(NUM_STEPS):\n",
        "  # One update using only the i-th slice of the training data.\n",
        "  p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)\n",
        "\n",
        "  # loss specific to only what we observe.\n",
        "  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))\n",
        "  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))\n",
        "  predictions, _ = network.forward(p_fw, x_observed_so_far)\n",
        "  loss, _ = network.compute_loss(predictions, y_observed_so_far)\n",
        "  tr_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "  # Plotting for the continuous input range\n",
        "  predictions, _ = network.forward(p_fw, xrange_inputs)\n",
        "  if (i+1) % plot_every == 0:\n",
        "    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))\n",
        "  loss, _ = network.compute_loss(predictions, targets)\n",
        "  ev_losses.append(tf.reduce_mean(loss))\n",
        "\n",
        "plt.xlabel('x')\n",
        "plt.ylabel('y')\n",
        "plt.legend()\n",
        "plt.show()\n",
        "\n",
        "# Entry k of each loss list is measured after k+1 updates, so the x axis\n",
        "# starts at 1 to line up with the '{}-step predictions' labels above.\n",
        "step_axis = np.arange(1, len(tr_losses) + 1)\n",
        "plt.plot(step_axis, tr_losses, label='tr_losses')\n",
        "plt.plot(step_axis, ev_losses, label='ev_losses')\n",
        "plt.xlabel('num steps')\n",
        "plt.ylabel('loss')\n",
        "plt.legend()\n",
        "plt.show()\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//experimental/selforg:selforg_colab",
        "kind": "private"
      },
      "name": "mplp_MAML_sinusoid_2x10.ipynb",
      "provenance": [
        {
          "file_id": "15XTc7xyCLONSR783k1AsAKCxwwYLu2eI",
          "timestamp": 1593621878830
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/self_organising_systems/mplp/notebooks/MAML_Sinusoid_L2L_2x10.ipynb?workspaceId=etr:twp::citc",
          "timestamp": 1593617126821
        },
        {
          "file_id": "19rXuMkLauVEXGFbg86cdvLpTi_IAaJ6q",
          "timestamp": 1593437725015
        },
        {
          "file_id": "13MWaLVCK11RJlFnmuz497CVfT2z_Mnze",
          "timestamp": 1589274038451
        },
        {
          "file_id": "1cb0W55ElZyZyCWcditlo9eLJ7VvmK3KX",
          "timestamp": 1587031457071
        },
        {
          "file_id": "1Id1-A7OS9ytTC52Y0T2S7mJjrwuuQ3jO",
          "timestamp": 1584949948813
        },
        {
          "file_id": "1bdTUOQlh7DvhGmH7lsmlYDC0y3k762jb",
          "timestamp": 1582372543431
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
